"""
Helper for selecting a binary mask at a desired coverage level from a
positive scalar field.  The function defined here performs a simple
binary search on the quantile level of the field's positive values to
achieve a fractional pixel coverage between ``target_low`` and
``target_high``.  If no evaluated level lands in that band within
``max_iters`` iterations, it returns whichever candidate mask has
coverage closest to the mid-point of the target interval.

This is used by the compact curvature translator to pick an S⁺ mask
whose area falls within a reasonable band (e.g. 0.5–5 % of pixels).
"""

from __future__ import annotations
import numpy as np
from typing import Tuple, Optional

def mask_with_target_coverage(pos_map: np.ndarray,
                              target_low: float = 0.005,
                              target_high: float = 0.05,
                              max_iters: int = 12) -> Tuple[np.ndarray, Optional[float], float]:
    """
    Produce a binary mask from ``pos_map`` whose fractional coverage
    lies between ``target_low`` and ``target_high``.  Only positive
    values of ``pos_map`` are considered when computing the quantile.

    Parameters
    ----------
    pos_map : np.ndarray
        A non‑negative 2D array (e.g. sum of positive LoG responses).
    target_low : float
        Lower bound of the desired coverage (fraction of total pixels,
        not percentage).  Defaults to 0.005 (0.5 %).
    target_high : float
        Upper bound of the desired coverage (fraction).  Defaults to
        0.05 (5 %).
    max_iters : int
        Maximum number of bisection iterations.  Defaults to 12.

    Returns
    -------
    mask : np.ndarray
        A boolean mask of the same shape as ``pos_map``.  May be all
        zeros if ``pos_map`` contains no positive values.
    quantile_used : Optional[float]
        The quantile threshold used to produce the mask.  ``None`` if
        no threshold was applied (e.g. ``pos_map`` contained no
        positive values).
    cov : float
        Fractional coverage of the resulting mask (between 0 and 1).
    """
    A = np.asarray(pos_map, dtype=np.float64)
    nz = A[A > 0]
    if nz.size == 0:
        return np.zeros_like(A, dtype=bool), None, 0.0
    # Clamp target bounds to sensible interval
    low = max(1e-8, float(target_low))
    high = min(1.0, float(target_high))
    q_lo, q_hi = 0.90, 0.9995  # initial search bracket
    cov = 0.0; thr_used = None
    for _ in range(int(max_iters)):
        q = 0.5 * (q_lo + q_hi)
        try:
            thr = np.quantile(nz, q)
        except Exception:
            thr = nz.max()
        M = (A >= thr)
        cov = M.mean()
        if cov > high:
            # threshold too low, mask too large -> increase q
            q_lo = q
        elif cov < low:
            # threshold too high, mask too small -> decrease q
            q_hi = q
        else:
            thr_used = q
            return M, thr_used, cov
    # If we exit without hitting the band, choose closest bound
    # Compute both endpoints and pick the one closer to midpoint
    thr_lo = np.quantile(nz, q_lo);
    M_lo = (A >= thr_lo);
    cov_lo = M_lo.mean()
    thr_hi = np.quantile(nz, q_hi);
    M_hi = (A >= thr_hi);
    cov_hi = M_hi.mean()
    mid = 0.5 * (low + high)
    if abs(cov_lo - mid) < abs(cov_hi - mid):
        return M_lo, q_lo, cov_lo
    else:
        return M_hi, q_hi, cov_hi

# Public API of this module; names not listed here are internal.
__all__ = ["mask_with_target_coverage"]